from full_system import *
from experiments_utils import *

base_folder = "/home/moemen/projects/def-guzdial/moemen/final"

dataset_name = "reddit"

# Dataset-specific input folders and VAE/PCA artifacts: "reddit" uses the
# "_reddit"-suffixed variants, any other dataset name uses the unsuffixed ones.
if dataset_name != "reddit":
    data_folder = os.path.join(base_folder, "data_split/train")
    delta_folder = os.path.join(base_folder, "out_split/train")
    vae_path = os.path.join(base_folder, 'vae', dataset_name, "llama_vae_state.pt")
    pca_folder = os.path.join(base_folder, 'vae', dataset_name, "MultiheadPCA")
else:
    data_folder = os.path.join(base_folder, "data_split_reddit/train")
    delta_folder = os.path.join(base_folder, "out_split_reddit/train")
    vae_path = os.path.join(base_folder, 'vae', dataset_name, "llama_vae_state_reddit.pt")
    pca_folder = os.path.join(base_folder, 'vae', dataset_name, "MultiheadPCA_reddit")
# The VAE metadata file has the same name for every dataset.
vae_info_path = os.path.join(base_folder, 'vae', dataset_name, "vae_info.json")

# Passed to DoppelWriter together with pca_folder.
if_pca = True

# Which final result files the script produces.
cross_entropy = True
generations = True

# Sampling settings used by generate_examples().
TOP_K = 20
TOP_P = 0.9
TEMPERATURE = 0.9
N_EXAMPLES = 10  # samples generated per tuned model
N_TOKENS = 30  # passed to the pipeline as max_length; alternative: 60

# Number of training documents (a prefix of the selected corpus) to try.
N_SAMPLES_LIST = [1] + list(range(2, 21, 2))

# Identifier handed to load_routine().
routine_id = 'routine_0'

def generate_examples(model, tokenizer, prompt = "", n_examples=1, n_tokens=100):
    """Sample ``n_examples`` texts from ``model`` starting at ``prompt``.

    Builds a ``text-generation`` pipeline with sampling enabled, using the
    module-level TOP_K / TOP_P / TEMPERATURE settings and ``n_tokens`` as
    ``max_length``. The pipeline is called once per requested example, and
    the ``'generated_text'`` of the first result of each call is kept.

    Returns:
        A list with one generated string per call, in call order.
    """
    text_generator = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=n_tokens,
        do_sample=True,
        top_k=TOP_K,
        top_p=TOP_P,
        temperature=TEMPERATURE,
    )
    return [
        text_generator(f"{prompt}")[0]['generated_text']
        for _ in range(n_examples)
    ]

if __name__ == "__main__":
    # Usage: <script> <test_data_folder> <corpus_name>
    # <test_data_folder>/<corpus_name>/ must contain a file whose name includes
    # 'test' and a file whose name includes 'train'.
    if len(sys.argv) < 3:
        sys.exit("usage: %s <test_data_folder> <corpus_name>" % sys.argv[0])
    test_data_folder = sys.argv[1]
    corpus_name = sys.argv[2]
    print("-------------------------------CORPUS:", corpus_name)
    curr_data_folder = os.path.join(test_data_folder, corpus_name)
    split_files = os.listdir(curr_data_folder)
    test_data_file = [name for name in split_files if 'test' in name][0]
    train_data_file = [name for name in split_files if 'train' in name][0]

    # Documents are separated by blank lines; the chunk after the last
    # separator is dropped. `with` guarantees the handles are closed.
    with open(os.path.join(curr_data_folder, test_data_file), 'r') as handle:
        test_corpus = handle.read().split('\n\n')[:-1]
    with open(os.path.join(curr_data_folder, train_data_file), 'r') as handle:
        train_corpus = handle.read().split('\n\n')[:-1]

    my_system = DoppelWriter(data_folder, delta_folder, vae_path, vae_info_path, if_pca, pca_folder)

    # load_routine supplies the base model names and the indices of the
    # training documents used for every experiment below.
    model_names, corpus_ids = load_routine(corpus_name, routine_id)

    corpus = [train_corpus[i] for i in corpus_ids]

    bank_output_folder, interpolation_output_folder, final_output_folder = initialize_corpus(corpus_name)
    initialize_base_models(bank_output_folder, model_names)

    ### Checking if models exist (models already in the bank are skipped)

    found = check_if_models_exist(bank_output_folder, model_names, corpus_ids)

    for i, model_name in enumerate(model_names):
        if found[i] is not None:
            print("------Found", model_name, "-------")
            print("")
            continue

        os_deltas = []
        os_losses = []
        print("MODEL", model_name)
        print("CORPUS", corpus_name, corpus_ids)
        for n_samples in N_SAMPLES_LIST:
            print(n_samples)
            # Run backprob_one_model on the first n_samples selected documents,
            # then score the resulting delta on the held-out test corpus.
            sample_corpus = corpus[:n_samples]
            _, one_step_delta = my_system.backprob_one_model(sample_corpus, model_name)
            one_step_loss = cross_entropy_evaluation(my_system.get_tuned_model(one_step_delta), my_system.tokenizer, test_corpus)
            os_deltas.append(one_step_delta)
            os_losses.append(one_step_loss)

        save_finetuned_model_data(bank_output_folder, model_name, corpus_ids, (os_losses, os_deltas))

    ## Load models data

    found = check_if_models_exist(bank_output_folder, model_names, corpus_ids)
    my_data_list = []
    for folder in found:
        if folder is None:
            print("MODEL NOT FOUND")
            # Do not continue with a partial list: my_data_list would be shorter
            # than model_names, which is still passed to interpolate() below.
            sys.exit(1)
        # NOTE: pickle files are produced by this pipeline; never point this at
        # untrusted data.
        with open(os.path.join(folder, 'finetuning_data.pkl'), 'rb') as handle:
            my_data_list.append(pickle.load(handle))

    # Field 0 of the saved tuple is the losses, field 1 the deltas (see
    # save_finetuned_model_data above). collect_data presumably regroups them
    # by sample-count index -- verify in experiments_utils.
    os_losses = collect_data(0, my_data_list, N_SAMPLES_LIST)
    os_deltas = collect_data(1, my_data_list, N_SAMPLES_LIST)

    ### Checking if interpolation exists (compute and save it when missing)
    int_types = ['vanilla_linear', 'averaging']
    for int_type in int_types:
        initialize_interpolation(interpolation_output_folder, int_type)
        found_interpolation_folder = check_if_interpolation_exists(interpolation_output_folder, model_names, corpus_ids, int_type)
        if found_interpolation_folder is not None:
            continue

        latent_models_all = []
        n_losses = []
        found_solution_all = []
        for i in range(len(N_SAMPLES_LIST)):
            print(i)
            latent_models, found_solution = my_system.interpolate(model_names, os_deltas, i, int_type)
            latent_models_all.append(latent_models)

            # Decode every latent model to a full delta, then evaluate each one.
            new_deltas = [my_system.decode_latent_model(latent_vec) for latent_vec in latent_models]
            new_losses = [
                cross_entropy_evaluation(my_system.get_tuned_model(delta), my_system.tokenizer, test_corpus)
                for delta in new_deltas
            ]
            n_losses.append(new_losses)

            found_solution_all.append(found_solution)

        save_interpolation_data(interpolation_output_folder, model_names, corpus_ids, int_type, (n_losses, latent_models_all, found_solution_all))

    ### Load interpolation data and Generate final results

    for int_type in int_types:
        print("Generating results for", int_type, "interpolation...")
        found_interpolation_folder = check_if_interpolation_exists(interpolation_output_folder, model_names, corpus_ids, int_type)
        if found_interpolation_folder is None:
            print("INTERPOLATION NOT FOUND")
            # Skip this type: falling through would use an undefined
            # interpolation_id (NameError) or the previous type's data.
            continue

        interpolation_id = found_interpolation_folder.split('/')[-1]
        with open(os.path.join(found_interpolation_folder, 'interpolation_data.pkl'), 'rb') as handle:
            n_losses, latent_models_all, found_solution_all = pickle.load(handle)

        output_folder = initialize_final_results(final_output_folder, interpolation_output_folder, int_type, interpolation_id)

        cross_entropy_file_name = 'cross_entropy.pkl'
        generations_file_name = 'generations.pkl'

        # Existing result files are never overwritten.
        cross_entropy_file_path = os.path.join(output_folder, cross_entropy_file_name)
        if cross_entropy and not os.path.exists(cross_entropy_file_path):
            save_pickle_file(cross_entropy_file_path, (os_losses, n_losses, found_solution_all))
            print("Saved cross entropy results!")
        else:
            print("Skipped cross entropy results :(")

        generations_file_path = os.path.join(output_folder, generations_file_name)
        if generations and not os.path.exists(generations_file_path):
            os_generations = []
            n_generations = []
            for i in range(len(N_SAMPLES_LIST)):
                print(i)
                # Samples from the one-step fine-tuned models ...
                one_step_gen = [
                    generate_examples(my_system.get_tuned_model(os_delta), my_system.tokenizer, n_examples=N_EXAMPLES, n_tokens=N_TOKENS)
                    for os_delta in os_deltas[i]
                ]
                # ... and from the decoded interpolated models.
                new_gen = []
                for latent_vec in latent_models_all[i]:
                    new_full_delta = my_system.decode_latent_model(latent_vec)
                    new_gen.append(generate_examples(my_system.get_tuned_model(new_full_delta), my_system.tokenizer, n_examples=N_EXAMPLES, n_tokens=N_TOKENS))

                os_generations.append(one_step_gen)
                n_generations.append(new_gen)

            save_pickle_file(generations_file_path, (os_generations, n_generations, found_solution_all))
            print("Saved generations results!")
        else:
            print("Skipped generations results :)")






    
    



    
